﻿// *********************************************************
// 
//     Copyright (c) Microsoft. All rights reserved.
//     This code is licensed under the Apache License, Version 2.0.
//     THIS CODE IS PROVIDED *AS IS* WITHOUT WARRANTY OF
//     ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING ANY
//     IMPLIED WARRANTIES OF FITNESS FOR A PARTICULAR
//     PURPOSE, MERCHANTABILITY, OR NON-INFRINGEMENT.
// 
// *********************************************************
using System.Collections.Generic;
using System.Text;
using System;

namespace Bio.Algorithms.Alignment.MultipleSequenceAlignment
{
    #region Enum
    /// <summary>
    /// Selects the function used to compare two kmer count vectors.
    /// 
    /// Each kmer count dictionary is viewed as a vector in a Euclidean space with one
    /// dimension per possible kmer; the members below are different ways of turning a
    /// pair of such vectors into a single real number.
    /// 
    /// Further alignment-free comparison measures are surveyed in
    /// "Alignment-free sequence comparison - a review" (Vinga 2002).
    /// </summary>
    public enum DistanceFunctionTypes
    {
        /// <summary>
        /// Euclidean distance: sqrt(sum((c1 - c2)^2)).
        /// </summary>
        EuclideanDistance = 0,

        /// <summary>
        /// Pearson correlation coefficient of the two count vectors.
        /// </summary>
        PearsonCorrelation = 1,

        /// <summary>
        /// Covariance of the two count vectors.
        /// </summary>
        CoVariance = 2,

        /// <summary>
        /// Modified form of the kmer distance function used in MUSCLE.
        /// </summary>
        ModifiedMUSCLE = 3
    }
    #endregion

    /// <summary>
    /// Signature shared by every kmer distance function: maps a pair of kmer count
    /// dictionaries to a single real number.
    /// 
    /// For a fixed kmer length and alphabet, both count vectors conceptually have the same
    /// length (one entry per possible kmer). Because these vectors are sparse, only kmers
    /// that actually occur are stored as dictionary keys; a kmer that is absent from a
    /// dictionary is taken to have a count of 0.
    /// 
    /// KmerDistanceScoreCalculator.CalculateKmerCounting produces raw (un-normalized) counts.
    /// </summary>
    /// <param name="countsDA">kmer count dictionary, e.g. from KmerDistanceScoreCalculator.CalculateKmerCounting</param>
    /// <param name="countsDB">kmer count dictionary, e.g. from KmerDistanceScoreCalculator.CalculateKmerCounting</param>
    /// <returns>the score the implementing function assigns to the pair of count vectors</returns>
    internal delegate float DistanceFunctionSelector(Dictionary<string, float> countsDA,
                                                     Dictionary<string, float> countsDB);

    /// <summary>
    /// Calculates distance scores between two sequences by comparing their kmer counts.
    /// 
    /// CalculateKmerCounting enumerates every k-length window of an input ISequence and
    /// counts how many times each kmer occurs.
    /// 
    /// Inputs: ISequence and kmer length.
    /// Method: slide a k-length window along the sequence and count how many times
    /// each kmer occurs.
    /// 
    /// Given the sequence length L and alphabet size a, the number of possible kmers is a^k,
    /// but since there are only L-k+1 windows, at most L-k+1 distinct kmers can occur.
    /// 
    /// When a^k >> L-k+1 (which is almost always the case for short proteins and short DNA),
    /// a *Dictionary* stores only the occurring kmers and their counts, so that the a^k
    /// possible kmers never have to be listed. Kmers missing from a dictionary have count 0.
    /// 
    /// Kmer counts are calculated once per sequence. The distance score between two
    /// sequences is then calculated by comparing the two count vectors.
    /// </summary>
    public sealed class KmerDistanceScoreCalculator
    {
        #region Member Variables

        // The distance function selected in the constructor.
        private readonly DistanceFunctionSelector _distanceFunction;

        // The length of kmer in this class
        private readonly int _kmerLength;

        // The number of possible kmers (alphabet size ^ kmer length), i.e. the dimension of the
        // count vectors. Stored as double: a^k overflows Int32 for realistic kmer lengths
        // (e.g. 25^7 > int.MaxValue), which used to corrupt every statistic derived from it.
        private readonly double _numberOfPossibleKmers;

        #endregion

        #region Constructors
        /// <summary>
        /// Construct a calculator that uses the Euclidean distance function.
        /// </summary>
        /// <param name="kmerLength">positive integer kmer length</param>
        /// <param name="alphabetType">molecule type: DNA, RNA or Protein</param>
        public KmerDistanceScoreCalculator(int kmerLength, IAlphabet alphabetType)
            : this(kmerLength, alphabetType, DistanceFunctionTypes.EuclideanDistance)
        {
        }

        /// <summary>
        /// Construct a calculator with the selected distance function.
        /// 
        /// The distance function is fixed when the instance is constructed and cannot be
        /// changed afterwards.
        /// </summary>
        /// <param name="kmerLength">positive integer kmer length</param>
        /// <param name="alphabetType">molecule type: DNA, RNA or Protein</param>
        /// <param name="DistanceFunctionName">DistanceFunctionTypes member</param>
        /// <exception cref="ArgumentException">
        /// kmerLength is not positive; alphabetType is not a DNA, RNA or protein alphabet
        /// (including null); or DistanceFunctionName is not a defined member.
        /// </exception>
        public KmerDistanceScoreCalculator(int kmerLength, IAlphabet alphabetType, DistanceFunctionTypes DistanceFunctionName)
        {
            if (kmerLength <= 0)
            {
                throw new ArgumentException("Kmer length needs to be positive");
            }

            _kmerLength = kmerLength;

            // DNA and RNA are both sized as 15 symbols, protein as 25.
            // NOTE(review): presumably these sizes include ambiguity symbols — confirm against the alphabets.
            if (alphabetType is DnaAlphabet || alphabetType is RnaAlphabet)
            {
                _numberOfPossibleKmers = Math.Pow(15, kmerLength);
            }
            else if (alphabetType is ProteinAlphabet)
            {
                _numberOfPossibleKmers = Math.Pow(25, kmerLength);
            }
            else
            {
                // ArgumentException derives from Exception, so existing catch blocks still match.
                throw new ArgumentException("Invalid molecular type");
            }

            switch (DistanceFunctionName)
            {
                case DistanceFunctionTypes.EuclideanDistance:
                    _distanceFunction = EuclideanDistance;
                    break;
                case DistanceFunctionTypes.CoVariance:
                    _distanceFunction = CoVariance;
                    break;
                case DistanceFunctionTypes.PearsonCorrelation:
                    _distanceFunction = PearsonCorrelation;
                    break;
                case DistanceFunctionTypes.ModifiedMUSCLE:
                    _distanceFunction = ModifiedMUSCLE;
                    break;
                default:
                    throw new ArgumentException("Similarity Function Name is not in the list...");
            }
        }
        #endregion

        #region Methods
        /// <summary>
        /// Calculate the score between two kmer count dictionaries generated by the
        /// CalculateKmerCounting method of this class, using the distance function
        /// selected in the constructor.
        /// </summary>
        /// <param name="countsDA">kmer counting dictionary</param>
        /// <param name="countsDB">kmer counting dictionary</param>
        /// <returns>the score computed by the selected distance function</returns>
        /// <exception cref="ArgumentNullException">either dictionary is null</exception>
        public float CalculateDistanceScore(Dictionary<String, float> countsDA,
                                            Dictionary<String, float> countsDB)
        {
            if (countsDA == null)
            {
                throw new ArgumentNullException("countsDA");
            }
            if (countsDB == null)
            {
                throw new ArgumentNullException("countsDB");
            }

            return _distanceFunction(countsDA, countsDB);
        }


        /// <summary>
        /// Slide a k-length window along the sequence and count every kmer it covers.
        /// Occurring kmers are represented as Strings (one char per sequence symbol), and
        /// their counts are stored in a Dictionary.
        /// 
        /// The counts are raw numbers of occurrences; they are not normalized.
        /// </summary>
        /// <param name="seq">input sequence (unaligned)</param>
        /// <param name="kmerLength">positive integer kmer length</param>
        /// <returns>map from each occurring kmer to its number of occurrences</returns>
        /// <exception cref="ArgumentNullException">seq is null</exception>
        /// <exception cref="ArgumentException">
        /// kmerLength is not positive or is larger than the sequence length
        /// </exception>
        public static Dictionary<String, float> CalculateKmerCounting(ISequence seq, int kmerLength)
        {
            if (seq == null)
            {
                throw new ArgumentNullException("seq");
            }
            if (kmerLength <= 0)
            {
                throw new ArgumentException("Kmer length needs to be positive");
            }
            if (kmerLength > seq.Count)
            {
                throw new ArgumentException("Kmer length is larger than the sequence length.");
            }

            Dictionary<String, float> countsDictionary = new Dictionary<String, float>();

            // 'kmer' holds the current window. Sliding the window by one position appends
            // the next symbol and drops the first one, so capacity k+1 is always enough.
            StringBuilder kmer = new StringBuilder(kmerLength + 1);

            for (int i = 0; i < kmerLength; ++i)
            {
                kmer.Append((char)seq[i]);
            }
            countsDictionary[kmer.ToString()] = 1;

            for (int i = kmerLength; i < seq.Count; ++i)
            {
                kmer.Append((char)seq[i]);
                kmer.Remove(0, 1);

                String kmerString = kmer.ToString();

                // TryGetValue leaves 'count' at 0 for a kmer not seen yet (single lookup).
                float count;
                countsDictionary.TryGetValue(kmerString, out count);
                countsDictionary[kmerString] = count + 1;
            }
            return countsDictionary;
        }
        #endregion

        
        #region Private methods
        // Check out Enum DistanceFunctionTypes and delegate DistanceFunctionSelector for details.
        // All functions treat a kmer missing from a dictionary as a count of 0.

        // Euclidean distance between the raw count vectors: sqrt(sum((a - b)^2)).
        // Kmers absent from both dictionaries contribute 0 and are skipped.
        // NOTE(review): counts are not normalized by sequence length here — confirm that is intended.
        private float EuclideanDistance(Dictionary<String, float> countsDA,
                                        Dictionary<String, float> countsDB)
        {
            float result = 0;
            float countB;

            // Kmers present in A (shared with B, or only in A).
            foreach (KeyValuePair<String, float> pair in countsDA)
            {
                float difference = countsDB.TryGetValue(pair.Key, out countB)
                    ? pair.Value - countB
                    : pair.Value;
                result += difference * difference;
            }

            // Kmers present only in B.
            foreach (KeyValuePair<String, float> pair in countsDB)
            {
                if (!countsDA.ContainsKey(pair.Key))
                {
                    result += pair.Value * pair.Value;
                }
            }
            return (float)Math.Sqrt(result);
        }

        // Pearson correlation coefficient of the two count vectors over all
        // _numberOfPossibleKmers dimensions: cov(A, B) / (sd(A) * sd(B)).
        // The result is not finite when either vector has zero variance (e.g. an empty dictionary).
        private float PearsonCorrelation(Dictionary<String, float> countsDA,
                                         Dictionary<String, float> countsDB)
        {
            double meanA, varianceA, meanB, varianceB;
            ComputeMeanAndVariance(countsDA, out meanA, out varianceA);
            ComputeMeanAndVariance(countsDB, out meanB, out varianceB);

            double covariance = ComputeCovariance(countsDA, countsDB, meanA, meanB);

            return (float)(covariance / (Math.Sqrt(varianceA) * Math.Sqrt(varianceB)));
        }

        // Population covariance of the two count vectors over all _numberOfPossibleKmers dimensions.
        private float CoVariance(Dictionary<String, float> countsDA,
                                 Dictionary<String, float> countsDB)
        {
            double meanA = ComputeMean(countsDA);
            double meanB = ComputeMean(countsDB);

            return (float)ComputeCovariance(countsDA, countsDB, meanA, meanB);
        }

        // Modified MUSCLE score: square root of the sum, over kmers present in both
        // dictionaries, of the smaller of the two counts. More shared kmers give a
        // larger value.
        // NOTE(review): this grows with similarity — confirm callers expect that orientation.
        private float ModifiedMUSCLE(Dictionary<String, float> countsDA,
                                     Dictionary<String, float> countsDB)
        {
            float result = 0;
            float countB;

            foreach (KeyValuePair<String, float> pair in countsDA)
            {
                if (countsDB.TryGetValue(pair.Key, out countB))
                {
                    result += Math.Min(pair.Value, countB);
                }
            }
            return (float)Math.Sqrt(result);
        }

        // Mean of a count vector over all _numberOfPossibleKmers dimensions
        // (absent kmers count as 0, so only the stored values need summing).
        private double ComputeMean(Dictionary<String, float> counts)
        {
            double sum = 0;
            foreach (float count in counts.Values)
            {
                sum += count;
            }
            return sum / _numberOfPossibleKmers;
        }

        // Mean and population variance of a count vector over all _numberOfPossibleKmers
        // dimensions, including the zero entries of kmers missing from the dictionary.
        private void ComputeMeanAndVariance(Dictionary<String, float> counts,
                                            out double mean, out double variance)
        {
            mean = ComputeMean(counts);

            double squaredDeviations = 0;
            foreach (float count in counts.Values)
            {
                double deviation = count - mean;
                squaredDeviations += deviation * deviation;
            }

            // Each of the (N - |counts|) absent kmers has value 0, i.e. deviates from the mean by -mean.
            squaredDeviations += (_numberOfPossibleKmers - counts.Count) * mean * mean;

            variance = squaredDeviations / _numberOfPossibleKmers;
        }

        // Population covariance (1/N) * sum((a - meanA) * (b - meanB)) over all
        // _numberOfPossibleKmers dimensions, given the precomputed means of both vectors.
        private double ComputeCovariance(Dictionary<String, float> countsDA,
                                         Dictionary<String, float> countsDB,
                                         double meanA, double meanB)
        {
            double crossSum = 0;
            int unionSize = countsDA.Count;
            float countB;

            // Kmers present in A; B's count is 0 when the kmer is missing from B.
            foreach (KeyValuePair<String, float> pair in countsDA)
            {
                double valueB = countsDB.TryGetValue(pair.Key, out countB) ? countB : 0;
                crossSum += (pair.Value - meanA) * (valueB - meanB);
            }

            // Kmers present only in B (A's count is 0).
            foreach (KeyValuePair<String, float> pair in countsDB)
            {
                if (!countsDA.ContainsKey(pair.Key))
                {
                    crossSum += (-meanA) * (pair.Value - meanB);
                    ++unionSize;
                }
            }

            // Kmers absent from both vectors each contribute (0 - meanA) * (0 - meanB).
            crossSum += (_numberOfPossibleKmers - unionSize) * meanA * meanB;

            return crossSum / _numberOfPossibleKmers;
        }
        #endregion
    }
}
